{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.keras.layers import Dense, Flatten \n",
    "from tensorflow.keras import Model\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Fetch the MNIST train/test splits\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "# Training pipeline: cast to float32, divide by 255, then shuffle and batch\n",
    "train_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_train.astype('float32') / 255, y_train))\n",
    "train_ds = train_ds.shuffle(1024).batch(32)\n",
    "\n",
    "# Evaluation pipeline: only the test samples labelled 2, same scaling, batched\n",
    "test_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_test[y_test == 2].astype('float32') / 255, y_test[y_test == 2]))\n",
    "test_ds = test_ds.batch(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VFLPassiveModel(Model):\n",
    "    \"\"\"Passive-party model: flattens its input and applies one 10-unit Dense layer.\n",
    "\n",
    "    The Dense layer is created without an activation argument, so `call`\n",
    "    returns that layer's raw outputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = Flatten()\n",
    "        self.d1 = Dense(10, name=\"dense1\")\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Return the Dense layer outputs for the flattened input `x`.\"\"\"\n",
    "        flat = self.flatten(x)\n",
    "        return self.d1(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VFLActiveModel(Model):\n",
    "    \"\"\"Active-party model: adds the received tensors element-wise, then applies softmax.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(VFLActiveModel, self).__init__()\n",
    "        self.added = tf.keras.layers.Add()\n",
    "        # Build the Softmax layer once here instead of instantiating a new\n",
    "        # Keras layer on every forward pass inside `call`.\n",
    "        self.softmax = tf.keras.layers.Softmax()\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Return softmax of the element-wise sum of `x`, a list of same-shaped tensors.\"\"\"\n",
    "        x = self.added(x)\n",
    "        return self.softmax(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.405218243598938, Accuracy: 88.75, Test Loss: 0.4513963758945465, Test Accuracy: 87.30619812011719\n",
      "Epoch 2, Loss: 0.2906039357185364, Accuracy: 91.85832977294922, Test Loss: 0.4250626266002655, Test Accuracy: 87.8875961303711\n",
      "Epoch 3, Loss: 0.27544891834259033, Accuracy: 92.40166473388672, Test Loss: 0.459330677986145, Test Accuracy: 87.4030990600586\n",
      "Epoch 4, Loss: 0.26715561747550964, Accuracy: 92.53500366210938, Test Loss: 0.4625568687915802, Test Accuracy: 87.1124038696289\n",
      "Epoch 5, Loss: 0.2622341513633728, Accuracy: 92.7066650390625, Test Loss: 0.5050548315048218, Test Accuracy: 86.33721160888672\n"
     ]
    }
   ],
   "source": [
    "passive_model = VFLPassiveModel()\n",
    "active_model = VFLActiveModel()\n",
    "\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # Training: one optimizer step per batch of images and labels\n",
    "    for images, labels in train_ds:\n",
    "        with tf.GradientTape() as passive_tape:\n",
    "            # passive_model sends passive_output to active_model\n",
    "            passive_output = passive_model(images)\n",
    "            with tf.GradientTape() as active_tape:\n",
    "                # passive_output is a tensor, not a variable, so it must be watched explicitly\n",
    "                active_tape.watch(passive_output)\n",
    "                active_output = active_model([passive_output, passive_output])\n",
    "                loss = loss_object(labels, active_output)\n",
    "            # active_model sends passive_output_gradients back to passive_model\n",
    "            passive_output_gradients = active_tape.gradient(loss, passive_output)\n",
    "            # Surrogate loss: with the received gradients held constant, its gradient\n",
    "            # w.r.t. passive_output equals passive_output_gradients, so differentiating\n",
    "            # it on passive_tape continues the chain rule into passive_model's weights.\n",
    "            # tf.stop_gradient detaches the gradients without a round trip through NumPy.\n",
    "            passive_output_loss = tf.multiply(\n",
    "                passive_output, tf.stop_gradient(passive_output_gradients))\n",
    "        passive_weight_gradients = passive_tape.gradient(\n",
    "            passive_output_loss, passive_model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(passive_weight_gradients, passive_model.trainable_variables))\n",
    "\n",
    "        train_loss(loss)\n",
    "        train_accuracy(labels, active_output)\n",
    "\n",
    "    # Evaluation: forward pass only, over every batch of test_ds\n",
    "    for test_images, test_labels in test_ds:\n",
    "        passive_output = passive_model(test_images)\n",
    "        active_output = active_model([passive_output, passive_output])\n",
    "        t_loss = loss_object(test_labels, active_output)\n",
    "\n",
    "        test_loss(t_loss)\n",
    "        test_accuracy(test_labels, active_output)\n",
    "\n",
    "    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "    print(template.format(epoch+1,\n",
    "                        train_loss.result(),\n",
    "                        train_accuracy.result()*100,\n",
    "                        test_loss.result(),\n",
    "                        test_accuracy.result()*100))\n",
    "\n",
    "    # Reset the metrics for the next epoch\n",
    "    train_loss.reset_states()\n",
    "    train_accuracy.reset_states()\n",
    "    test_loss.reset_states()\n",
    "    test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def get_possioned_gradients(passive_output_gradients, passive_output, N, class1, class2, alpha = 1.0):\n",
    "    \"\"\"Move the `class1` column of the shifted gradients onto the `class2` column.\n",
    "\n",
    "    Computes ``(passive_output_gradients - passive_output) @ A + passive_output``.\n",
    "    ``A`` starts as the N x N identity; column `class1` times `alpha` is added to\n",
    "    column `class2`, then column `class1` is scaled by ``1 - alpha``. With the\n",
    "    default ``alpha = 1.0`` the `class1` column of the shifted values is zeroed\n",
    "    and its content is added to the `class2` column.\n",
    "\n",
    "    Args:\n",
    "        passive_output_gradients: gradients sent back for `passive_output`; the\n",
    "            last dimension must be N for the matrix product. Assumed float32 to\n",
    "            match the float32 attack matrix -- TODO confirm for other callers.\n",
    "        passive_output: values subtracted before and added back after the\n",
    "            product; presumably the same shape as the gradients.\n",
    "        N: side length of the attack matrix.\n",
    "        class1: column index whose content is moved away.\n",
    "        class2: column index that receives the moved content.\n",
    "        alpha: fraction of column `class1` that is moved.\n",
    "\n",
    "    Returns:\n",
    "        The result of ``tf.matmul`` with `passive_output` added back.\n",
    "    \"\"\"\n",
    "    # Out-of-place subtraction: `-=` on the argument would overwrite a NumPy\n",
    "    # array handed in by the caller.\n",
    "    shifted = passive_output_gradients - passive_output\n",
    "    attack_mat = np.eye(N, dtype='float32')\n",
    "    attack_mat[:, class2] += attack_mat[:, class1]*alpha\n",
    "    attack_mat[:, class1] -= attack_mat[:, class1]*alpha\n",
    "    poisoned = tf.matmul(shifted, attack_mat)\n",
    "    return poisoned + passive_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.5872719287872314, Accuracy: 81.26000213623047, Test Loss: 1.1180531978607178, Test Accuracy: 45.54263687133789\n",
      "Epoch 2, Loss: 0.47115832567214966, Accuracy: 83.81166076660156, Test Loss: 1.2937822341918945, Test Accuracy: 33.81782913208008\n",
      "Epoch 3, Loss: 0.47255071997642517, Accuracy: 83.44166564941406, Test Loss: 1.4451520442962646, Test Accuracy: 27.03488540649414\n",
      "Epoch 4, Loss: 0.4741647243499756, Accuracy: 83.53666687011719, Test Loss: 1.561234474182129, Test Accuracy: 23.643409729003906\n",
      "Epoch 5, Loss: 0.475004643201828, Accuracy: 83.59000396728516, Test Loss: 1.557436466217041, Test Accuracy: 25.387596130371094\n"
     ]
    }
   ],
   "source": [
    "passive_model = VFLPassiveModel()\n",
    "active_model = VFLActiveModel()\n",
    "\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "EPOCHS = 5\n",
    "\n",
    "# Arguments forwarded to get_possioned_gradients\n",
    "NUM_CLASSES = 10  # N: side length of the attack matrix\n",
    "SOURCE_CLASS = 2  # class1: column moved away\n",
    "TARGET_CLASS = 4  # class2: column that receives it\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # Training: one optimizer step per batch of images and labels\n",
    "    for images, labels in train_ds:\n",
    "        with tf.GradientTape() as passive_tape:\n",
    "            # passive_model sends passive_output to active_model\n",
    "            passive_output = passive_model(images)\n",
    "            with tf.GradientTape() as active_tape:\n",
    "                # passive_output is a tensor, not a variable, so it must be watched explicitly\n",
    "                active_tape.watch(passive_output)\n",
    "                active_output = active_model([passive_output, passive_output])\n",
    "                loss = loss_object(labels, active_output)\n",
    "            # active_model sends passive_output_gradients back to passive_model,\n",
    "            # after passing them through get_possioned_gradients\n",
    "            passive_output_gradients = active_tape.gradient(loss, passive_output)\n",
    "            passive_output_gradients = get_possioned_gradients(\n",
    "                passive_output_gradients, passive_output,\n",
    "                NUM_CLASSES, SOURCE_CLASS, TARGET_CLASS)\n",
    "            # Surrogate loss: with the received gradients held constant, its gradient\n",
    "            # w.r.t. passive_output equals passive_output_gradients. tf.stop_gradient\n",
    "            # detaches them without a round trip through NumPy.\n",
    "            passive_output_loss = tf.multiply(\n",
    "                passive_output, tf.stop_gradient(passive_output_gradients))\n",
    "        passive_weight_gradients = passive_tape.gradient(\n",
    "            passive_output_loss, passive_model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(passive_weight_gradients, passive_model.trainable_variables))\n",
    "\n",
    "        train_loss(loss)\n",
    "        train_accuracy(labels, active_output)\n",
    "\n",
    "    # Evaluation: forward pass only, over every batch of test_ds\n",
    "    for test_images, test_labels in test_ds:\n",
    "        passive_output = passive_model(test_images)\n",
    "        active_output = active_model([passive_output, passive_output])\n",
    "        t_loss = loss_object(test_labels, active_output)\n",
    "\n",
    "        test_loss(t_loss)\n",
    "        test_accuracy(test_labels, active_output)\n",
    "\n",
    "    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "    print(template.format(epoch+1,\n",
    "                        train_loss.result(),\n",
    "                        train_accuracy.result()*100,\n",
    "                        test_loss.result(),\n",
    "                        test_accuracy.result()*100))\n",
    "\n",
    "    # Reset the metrics for the next epoch\n",
    "    train_loss.reset_states()\n",
    "    train_accuracy.reset_states()\n",
    "    test_loss.reset_states()\n",
    "    test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 14.362151    9.        224.04999    47.196995  522.3977      0.637849\n",
      "  22.          6.9999685 178.35532     7.0000315]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAANyklEQVR4nO3dXYxdV3nG8f/TGAiEggOZRKltOqmwKFElSDRK3Uaq2hhV+UA4F4kU1CZW5Mo3oQ0FiRpuqkq9MFJFAKmKZCW0pqVAFECxkogS5UNVL5IyTtJ8YFCmaRpP7eKhJIEWUZry9mKW28E+9hyP58yJ1/n/pNHZ+91rzn63P55Zs+acPakqJEl9+blxNyBJWn2GuyR1yHCXpA4Z7pLUIcNdkjq0btwNAJx33nk1PT097jYk6Yyyf//+71XV1KBjr4lwn56eZnZ2dtxtSNIZJcm/nOiYyzKS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktSh18Q7VKVTMb3rvpE+/wu7rxnp80trwZm7JHXIcJekDhnuktQhw12SOjRUuCd5IcnTSZ5MMttqb0vyQJLn2uO5rZ4kn00yl+SpJJeO8gIkScc7lZn7b1XVe6tqpu3vAh6sqs3Ag20f4Cpgc/vYCdy+Ws1KkoZzOssy24C9bXsvcO2S+udr0aPA+iQXnsZ5JEmnaNhwL+AbSfYn2dlqF1TVYYD2eH6rbwAOLvnc+VaTJK2RYd/EdHlVHUpyPvBAkm+fZGwG1Oq4QYtfJHYCvOMd7xiyDUnSMIaauVfVofZ4BPgacBnw3aPLLe3xSBs+D2xa8ukbgUMDnnNPVc1U1czU1MDf7ypJWqFlwz3JOUl+/ug28NvAM8A+YHsbth24p23vA25qr5rZArxydPlGkrQ2hlmWuQD4WpKj4/+mqr6e5JvAXUl2AC8C17fx9wNXA3PAj4CbV71rSdJJLRvuVfU88J4B9X8Htg6oF3DLqnQnSVoR36EqSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHhg73JGcleSLJvW3/oiSPJXkuyZeTvL7V39D259rx6dG0Lkk6kVOZud8KHFiy/0ngtqraDLwE7Gj1HcBLVfVO4LY2TpK0hoYK9yQbgWuAO9p+gCuAu9uQvcC1bXtb26cd39rGS5LWyLAz908DHwN+2vbfDrxcVa+2/XlgQ9veABwEaMdfaeN/RpKdSWaTzC4sLKywfUnSIMuGe5L3A0eqav/S8oChNcSx/y9U7amqmaqamZqaGqpZSdJw1g0x5nLgA0muBs4G3sLiTH59knVtdr4RONTGzwObgPkk64C3At9f9c4lSSe07My9qj5eVRurahq4AXioqn4HeBi4rg3bDtzTtve1fdrxh6rquJm7JGl0Tud17n8EfCTJHItr6ne2+p3A21v9I8Cu02tRknSqhlmW+T9V9QjwSNt+HrhswJgfA9evQm+SpBXyHaqS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVo2XBPcnaSf0jyj0meTfInrX5RkseSPJfky0le3+pvaPtz7fj0aC9BknSsYWbu/wVcUVXvAd4LXJlkC/BJ4Laq2gy8BOxo43cAL1XVO4Hb2jhJ0hpaNtxr0X+03de1jwKuAO5u
9b3AtW17W9unHd+aJKvWsSRpWUOtuSc5K8mTwBHgAeCfgJer6tU2ZB7Y0LY3AAcB2vFXgLcPeM6dSWaTzC4sLJzeVUiSfsZQ4V5V/1NV7wU2ApcB7x40rD0OmqXXcYWqPVU1U1UzU1NTw/YrSRrCKb1apqpeBh4BtgDrk6xrhzYCh9r2PLAJoB1/K/D91WhWkjScYV4tM5Vkfdt+I/A+4ADwMHBdG7YduKdt72v7tOMPVdVxM3dJ0uisW34IFwJ7k5zF4heDu6rq3iTfAr6U5E+BJ4A72/g7gb9KMsfijP2GEfQtSTqJZcO9qp4CLhlQf57F9fdj6z8Grl+V7iRJK+I7VCWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR1aN+4GtDLTu+4b+Tle2H3NyM8haTScuUtShwx3SeqQ4S5JHTLcJalDhrskdWjZcE+yKcnDSQ4keTbJra3+tiQPJHmuPZ7b6kny2SRzSZ5KcumoL0KS9LOGmbm/Cny0qt4NbAFuSXIxsAt4sKo2Aw+2fYCrgM3tYydw+6p3LUk6qWXDvaoOV9XjbfuHwAFgA7AN2NuG7QWubdvbgM/XokeB9UkuXPXOJUkndEpr7kmmgUuAx4ALquowLH4BAM5vwzYAB5d82nyrSZLWyNDhnuTNwFeAD1fVD042dECtBjzfziSzSWYXFhaGbUOSNIShwj3J61gM9i9U1Vdb+btHl1va45FWnwc2Lfn0jcChY5+zqvZU1UxVzUxNTa20f0nSAMO8WibAncCBqvrUkkP7gO1teztwz5L6Te1VM1uAV44u30iS1sYwNw67HLgReDrJk632CWA3cFeSHcCLwPXt2P3A1cAc8CPg5lXtWJK0rGXDvar+nsHr6ABbB4wv4JbT7EuSdBp8h6okdchwl6QOGe6S1CF/E5Ok1zR/69jKOHOXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjq0bLgn+VySI0meWVJ7W5IHkjzXHs9t9ST5bJK5JE8luXSUzUuSBhtm5v6XwJXH1HYBD1bVZuDBtg9wFbC5fewEbl+dNiVJp2LZcK+qvwO+f0x5G7C3be8Frl1S/3wtehRYn+TC1WpWkjScla65X1BVhwHa4/mtvgE4uGTcfKsdJ8nOJLNJZhcWFlbYhiRpkNX+gWoG1GrQwKraU1UzVTUzNTW1ym1I0mRbabh/9+hyS3s80urzwKYl4zYCh1beniRpJVYa7vuA7W17O3DPkvpN7VUzW4BXji7fSJLWzrrlBiT5IvCbwHlJ5oE/BnYDdyXZAbwIXN+G3w9cDcwBPwJuHkHPkqRlLBvuVfXBExzaOmBsAbecblOSpNPjO1QlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nq0LL3c5eONb3rvpGf44Xd14z8HFLPnLlLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHfIeqdIYY9TuDfVdwX5y5S1KHDHdJ6tAZvyzjTawk6XjO3CWpQ2f8zH2c/K5B0mvVSMI9yZXAZ4CzgDuqavcoziNpbUzqROZMvu5VX5ZJchbw58BVwMXAB5NcvNrnkSSd2CjW3C8D5qrq+ar6CfAlYNsIziNJOoFU1eo+YXIdcGVV/V7bvxH41ar60DHjdgI72+67gO+saiMndx7wvTU832uF1z1ZvO7+/WJVTQ06MIo19wyo
HfcVpKr2AHtGcP5lJZmtqplxnHucvO7J4nVPtlEsy8wDm5bsbwQOjeA8kqQTGEW4fxPYnOSiJK8HbgD2jeA8kqQTWPVlmap6NcmHgL9l8aWQn6uqZ1f7PKdpLMtBrwFe92TxuifYqv9AVZI0ft5+QJI6ZLhLUocmKtyTXJnkO0nmkuwadz9rIcmmJA8nOZDk2SS3jruntZTkrCRPJLl33L2spSTrk9yd5Nvt7/7Xxt3TWkjyh+3f+TNJvpjk7HH3NC4TE+4TfFuEV4GPVtW7gS3ALRNy3UfdChwYdxNj8Bng61X1y8B7mIA/gyQbgD8AZqrqV1h8QccN4+1qfCYm3JnQ2yJU1eGqerxt/5DF/+QbxtvV2kiyEbgGuGPcvaylJG8BfgO4E6CqflJVL4+3qzWzDnhjknXAm5jg99hMUrhvAA4u2Z9nQkLuqCTTwCXAY+PtZM18GvgY8NNxN7LGfglYAP6iLUndkeSccTc1alX1r8CfAS8Ch4FXquob4+1qfCYp3Ie6LUKvkrwZ+Arw4ar6wbj7GbUk7weOVNX+cfcyBuuAS4Hbq+oS4D+B7n/GlORcFr8bvwj4BeCcJL873q7GZ5LCfWJvi5DkdSwG+xeq6qvj7meNXA58IMkLLC7BXZHkr8fb0pqZB+ar6uh3aHezGPa9ex/wz1W1UFX/DXwV+PUx9zQ2kxTuE3lbhCRhce31QFV9atz9rJWq+nhVbayqaRb/rh+qqomYxVXVvwEHk7yrlbYC3xpjS2vlRWBLkje1f/dbmYAfJJ/IxPyavTPktgijcDlwI/B0kidb7RNVdf8Ye9Lo/T7whTaReR64ecz9jFxVPZbkbuBxFl8l9gQTfCsCbz8gSR2apGUZSZoYhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nq0P8C8AxTMY129fIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Scale the images the same way as train_ds/test_ds (float32, divided by 255);\n",
    "# feeding raw 0-255 pixels to a model trained on scaled inputs distorts its outputs.\n",
    "x_val = x_test[y_test==2].astype('float32') / 255\n",
    "y_val = y_test[y_test==2]\n",
    "\n",
    "passive_output = passive_model(x_val)\n",
    "active_output = active_model([passive_output, passive_output])\n",
    "\n",
    "# Per-class sum of the softmax outputs over all test samples labelled 2\n",
    "output_distribution = np.sum(active_output, axis=0)\n",
    "print(output_distribution)\n",
    "\n",
    "n = 10\n",
    "X = np.arange(n)\n",
    "plt.bar(X, output_distribution)\n",
    "plt.xticks(X)\n",
    "plt.xlabel('Predicted class')\n",
    "plt.ylabel('Summed softmax output')\n",
    "plt.title('Output distribution for test samples labelled 2')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]\n",
      "[[0.5301766  0.36834355 0.0751616  0.34628156 0.67369025 0.71920954\n",
      "  0.39001503 0.05330744 0.60812417 0.57757241]\n",
      " [0.11696672 0.91286319 0.76614699 0.18240779 0.34124305 0.64034105\n",
      "  0.14179049 0.5599264  0.92628808 0.04293072]\n",
      " [0.8208049  0.58776502 0.82512238 0.75042469 0.143793   0.9864077\n",
      "  0.28055798 0.66674831 0.15096047 0.8785071 ]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.5301766 , 0.36834355, 0.        , 0.34628156, 0.74885185,\n",
       "        0.71920954, 0.39001503, 0.05330744, 0.60812417, 0.57757241],\n",
       "       [0.11696672, 0.91286319, 0.        , 0.18240779, 1.10739004,\n",
       "        0.64034105, 0.14179049, 0.5599264 , 0.92628808, 0.04293072],\n",
       "       [0.8208049 , 0.58776502, 0.        , 0.75042469, 0.96891539,\n",
       "        0.9864077 , 0.28055798, 0.66674831, 0.15096047, 0.8785071 ]])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check of the attack matrix on random data\n",
    "N = 10\n",
    "class1 = 2\n",
    "class2 = 4\n",
    "alpha = 1.0\n",
    "\n",
    "# Identity matrix; column class1 (times alpha) is added to column class2,\n",
    "# then column class1 is scaled by 1 - alpha\n",
    "attack_mat = np.eye(N, dtype='float32')\n",
    "attack_mat[:, class2] += alpha * attack_mat[:, class1]\n",
    "attack_mat[:, class1] -= alpha * attack_mat[:, class1]\n",
    "print(attack_mat)\n",
    "\n",
    "samples = np.random.rand(3, N)\n",
    "print(samples)\n",
    "\n",
    "# Last expression: shown as the cell result\n",
    "transformed = samples @ attack_mat\n",
    "transformed"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
